//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "GPUOpsLowering.h"

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

/// Looks up `name` in the symbol table of `moduleOp` and returns it if it is an
/// `llvm.func`. Otherwise declares a new external `llvm.func` with signature
/// `type` at the start of the first block of `moduleOp`'s first region.
/// The caller's insertion point on `b` is preserved.
///
/// NOTE(review): if `name` already names a symbol that is not an `llvm.func`,
/// a new declaration with the same name is still created.
LLVM::LLVMFuncOp mlir::getOrDefineFunction(Operation *moduleOp, Location loc,
                                           OpBuilder &b, StringRef name,
                                           LLVM::LLVMFunctionType type) {
  Operation *symbol = SymbolTable::lookupSymbolIn(moduleOp, name);
  if (auto found = dyn_cast_or_null<LLVM::LLVMFuncOp>(symbol))
    return found;

  // Emit the declaration at the top of the module body; the guard restores the
  // builder's previous insertion point when this function returns.
  OpBuilder::InsertionGuard insertionGuard(b);
  Block &moduleBody = moduleOp->getRegion(0).front();
  b.setInsertionPointToStart(&moduleBody);
  return LLVM::LLVMFuncOp::create(b, loc, name, type, LLVM::Linkage::External);
}

/// Returns the first name of the form `<prefix><N>`, trying N = 0, 1, 2, ...,
/// that does not resolve to a symbol in `moduleOp`'s symbol table.
static SmallString<16> getUniqueSymbolName(Operation *moduleOp,
                                           StringRef prefix) {
  SmallString<16> candidate;
  for (unsigned suffix = 0;; ++suffix) {
    candidate.clear();
    (prefix + Twine(suffix)).toStringRef(candidate);
    if (!SymbolTable::lookupSymbolIn(moduleOp, candidate))
      return candidate;
  }
}

/// Returns a constant `llvm.mlir.global` that holds `str` followed by a NUL
/// byte, typed as an array of `llvmI8`.
///
/// An existing global in `moduleOp`'s first region is reused when its type,
/// constness, value, alignment (absent counts as 0) and address space all
/// match. Otherwise a new internal global is created at the start of the
/// module body. Its name is `namePrefix` followed by the first unused number.
LLVM::GlobalOp mlir::getOrCreateStringConstant(OpBuilder &b, Location loc,
                                               Operation *moduleOp, Type llvmI8,
                                               StringRef namePrefix,
                                               StringRef str,
                                               uint64_t alignment,
                                               unsigned addrSpace) {
  // Stored payload is the string plus a trailing NUL so it is a valid C string.
  llvm::SmallString<20> payload(str);
  payload.push_back('\0');
  auto arrayTy = LLVM::LLVMArrayType::get(llvmI8, payload.size_in_bytes());
  StringAttr payloadAttr = b.getStringAttr(payload);

  // A global can be shared only if every observable property matches.
  auto isReusable = [&](LLVM::GlobalOp candidate) {
    return candidate.getGlobalType() == arrayTy && candidate.getConstant() &&
           candidate.getValueAttr() == payloadAttr &&
           candidate.getAlignment().value_or(0) == alignment &&
           candidate.getAddrSpace() == addrSpace;
  };
  for (LLVM::GlobalOp candidate :
       moduleOp->getRegion(0).getOps<LLVM::GlobalOp>()) {
    if (isReusable(candidate))
      return candidate;
  }

  // Nothing reusable: materialize a fresh internal constant at module start.
  OpBuilder::InsertionGuard insertionGuard(b);
  b.setInsertionPointToStart(&moduleOp->getRegion(0).front());
  SmallString<16> symName = getUniqueSymbolName(moduleOp, namePrefix);
  return LLVM::GlobalOp::create(b, loc, arrayTy,
                                /*isConstant=*/true, LLVM::Linkage::Internal,
                                symName, payloadAttr, alignment, addrSpace);
}

/// Lowers a `gpu.func` to an `llvm.func`.
///
/// Workgroup attributions are handled in one of two ways:
/// - When `encodeWorkgroupAttributionsAsArguments` is set, each one becomes an
///   extra trailing `llvm.ptr` argument in `workgroupAddrSpace`, tagged with
///   `noalias` and a workgroup-attribution attribute.
/// - Otherwise each one becomes an internal, non-constant global array in
///   `workgroupAddrSpace`.
///
/// Private attributions become `llvm.alloca`s in `allocaAddrSpace`. Every
/// attribution is wrapped in a statically-shaped memref descriptor and remapped
/// through the signature conversion before the body is inlined into the new
/// function. Selected argument attributes are then copied onto the converted
/// LLVM arguments.
///
/// Returns failure, rather than asserting, if an attribution element type or
/// the function signature cannot be converted.
LogicalResult
GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  Location loc = gpuFuncOp.getLoc();

  SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
  if (encodeWorkgroupAttributionsAsArguments) {
    // Append an `llvm.ptr` argument to the function signature to encode
    // workgroup attributions.

    ArrayRef<BlockArgument> workgroupAttributions =
        gpuFuncOp.getWorkgroupAttributions();
    size_t numAttributions = workgroupAttributions.size();

    // Insert all arguments at the end.
    unsigned index = gpuFuncOp.getNumArguments();
    SmallVector<unsigned> argIndices(numAttributions, index);

    // New arguments will simply be `llvm.ptr` with the correct address space
    Type workgroupPtrType =
        rewriter.getType<LLVM::LLVMPointerType>(workgroupAddrSpace);
    SmallVector<Type> argTypes(numAttributions, workgroupPtrType);

    // Attributes: noalias, llvm.mlir.workgroup_attribution(<size>, <type>)
    std::array attrs{
        rewriter.getNamedAttr(LLVM::LLVMDialect::getNoAliasAttrName(),
                              rewriter.getUnitAttr()),
        rewriter.getNamedAttr(
            getDialect().getWorkgroupAttributionAttrHelper().getName(),
            rewriter.getUnitAttr()),
    };
    SmallVector<DictionaryAttr> argAttrs;
    for (BlockArgument attribution : workgroupAttributions) {
      auto attributionType = cast<MemRefType>(attribution.getType());
      IntegerAttr numElements =
          rewriter.getI64IntegerAttr(attributionType.getNumElements());
      Type llvmElementType =
          getTypeConverter()->convertType(attributionType.getElementType());
      if (!llvmElementType)
        return failure();
      TypeAttr type = TypeAttr::get(llvmElementType);
      // The second entry starts as a unit placeholder; overwrite it with this
      // attribution's (size, element type) descriptor.
      attrs.back().setValue(
          rewriter.getAttr<LLVM::WorkgroupAttributionAttr>(numElements, type));
      argAttrs.push_back(rewriter.getDictionaryAttr(attrs));
    }

    // Location match function location
    SmallVector<Location> argLocs(numAttributions, gpuFuncOp.getLoc());

    // Perform signature modification
    rewriter.modifyOpInPlace(
        gpuFuncOp, [gpuFuncOp, &argIndices, &argTypes, &argAttrs, &argLocs]() {
          LogicalResult inserted =
              static_cast<FunctionOpInterface>(gpuFuncOp).insertArguments(
                  argIndices, argTypes, argAttrs, argLocs);
          (void)inserted;
          assert(succeeded(inserted) &&
                 "expected GPU funcs to support inserting any argument");
        });
  } else {
    workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
    for (auto [idx, attribution] :
         llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
      auto type = dyn_cast<MemRefType>(attribution.getType());
      assert(type && type.hasStaticShape() && "unexpected type in attribution");

      uint64_t numElements = type.getNumElements();

      // Report unconvertible element types as a match failure. `cast<Type>` on
      // a null type would assert in debug builds and let a null type escape
      // in release builds.
      Type elementType = typeConverter->convertType(type.getElementType());
      if (!elementType)
        return rewriter.notifyMatchFailure(
            gpuFuncOp, "failed to convert workgroup attribution element type");
      auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
      std::string name =
          std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
      uint64_t alignment = 0;
      if (auto alignAttr = dyn_cast_or_null<IntegerAttr>(
              gpuFuncOp.getWorkgroupAttributionAttr(
                  idx, LLVM::LLVMDialect::getAlignAttrName())))
        alignment = alignAttr.getInt();
      auto globalOp = LLVM::GlobalOp::create(
          rewriter, gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
          LLVM::Linkage::Internal, name, /*value=*/Attribute(), alignment,
          workgroupAddrSpace);
      workgroupBuffers.push_back(globalOp);
    }
  }

  // Remap proper input types.
  TypeConverter::SignatureConversion signatureConversion(
      gpuFuncOp.front().getNumArguments());

  Type funcType = getTypeConverter()->convertFunctionSignature(
      gpuFuncOp.getFunctionType(), /*isVariadic=*/false,
      getTypeConverter()->getOptions().useBarePtrCallConv, signatureConversion);
  if (!funcType) {
    return rewriter.notifyMatchFailure(gpuFuncOp, [&](Diagnostic &diag) {
      diag << "failed to convert function signature type for: "
           << gpuFuncOp.getFunctionType();
    });
  }

  // Create the new function operation. Only copy those attributes that are
  // not specific to function modeling.
  SmallVector<NamedAttribute, 4> attributes;
  ArrayAttr argAttrs;
  for (const auto &attr : gpuFuncOp->getAttrs()) {
    if (attr.getName() == SymbolTable::getSymbolAttrName() ||
        attr.getName() == gpuFuncOp.getFunctionTypeAttrName() ||
        attr.getName() ==
            gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName() ||
        attr.getName() == gpuFuncOp.getWorkgroupAttribAttrsAttrName() ||
        attr.getName() == gpuFuncOp.getPrivateAttribAttrsAttrName() ||
        attr.getName() == gpuFuncOp.getKnownBlockSizeAttrName() ||
        attr.getName() == gpuFuncOp.getKnownGridSizeAttrName())
      continue;
    if (attr.getName() == gpuFuncOp.getArgAttrsAttrName()) {
      argAttrs = gpuFuncOp.getArgAttrsAttr();
      continue;
    }
    attributes.push_back(attr);
  }

  DenseI32ArrayAttr knownBlockSize = gpuFuncOp.getKnownBlockSizeAttr();
  DenseI32ArrayAttr knownGridSize = gpuFuncOp.getKnownGridSizeAttr();
  // Ensure we don't lose information if the function is lowered before its
  // surrounding context.
  auto *gpuDialect = cast<gpu::GPUDialect>(gpuFuncOp->getDialect());
  if (knownBlockSize)
    attributes.emplace_back(gpuDialect->getKnownBlockSizeAttrHelper().getName(),
                            knownBlockSize);
  if (knownGridSize)
    attributes.emplace_back(gpuDialect->getKnownGridSizeAttrHelper().getName(),
                            knownGridSize);

  // Add a dialect specific kernel attribute in addition to GPU kernel
  // attribute. The former is necessary for further translation while the
  // latter is expected by gpu.launch_func.
  if (gpuFuncOp.isKernel()) {
    if (kernelAttributeName)
      attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
    // Set the dialect-specific block size attribute if there is one.
    if (kernelBlockSizeAttributeName && knownBlockSize) {
      attributes.emplace_back(kernelBlockSizeAttributeName, knownBlockSize);
    }
  }
  LLVM::CConv callingConvention = gpuFuncOp.isKernel()
                                      ? kernelCallingConvention
                                      : nonKernelCallingConvention;
  auto llvmFuncOp = LLVM::LLVMFuncOp::create(
      rewriter, gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
      LLVM::Linkage::External, /*dsoLocal=*/false, callingConvention,
      /*comdat=*/nullptr, attributes);

  {
    // Insert operations that correspond to converted workgroup and private
    // memory attributions to the body of the function. This must operate on
    // the original function, before the body region is inlined in the new
    // function to maintain the relation between block arguments and the
    // parent operation that assigns their semantics.
    OpBuilder::InsertionGuard guard(rewriter);

    // Rewrite workgroup memory attributions to addresses of global buffers.
    rewriter.setInsertionPointToStart(&gpuFuncOp.front());
    unsigned numProperArguments = gpuFuncOp.getNumArguments();

    if (encodeWorkgroupAttributionsAsArguments) {
      // Build a MemRefDescriptor with each of the arguments added above.

      unsigned numAttributions = gpuFuncOp.getNumWorkgroupAttributions();
      assert(numProperArguments >= numAttributions &&
             "Expecting attributions to be encoded as arguments already");

      // Arguments encoding workgroup attributions will be in positions
      // [numProperArguments, numProperArguments+numAttributions)
      ArrayRef<BlockArgument> attributionArguments =
          gpuFuncOp.getArguments().slice(numProperArguments - numAttributions,
                                         numAttributions);
      for (auto [idx, vals] : llvm::enumerate(llvm::zip_equal(
               gpuFuncOp.getWorkgroupAttributions(), attributionArguments))) {
        auto [attribution, arg] = vals;
        auto type = cast<MemRefType>(attribution.getType());

        // Arguments are of llvm.ptr type and attributions are of memref type:
        // we need to wrap them in memref descriptors.
        Value descr = MemRefDescriptor::fromStaticShape(
            rewriter, loc, *getTypeConverter(), type, arg);

        // And remap the arguments
        signatureConversion.remapInput(numProperArguments + idx, descr);
      }
    } else {
      for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
        auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
                                                  global.getAddrSpace());
        Value address = LLVM::AddressOfOp::create(rewriter, loc, ptrType,
                                                  global.getSymNameAttr());
        Value memory =
            LLVM::GEPOp::create(rewriter, loc, ptrType, global.getType(),
                                address, ArrayRef<LLVM::GEPArg>{0, 0});

        // Build a memref descriptor pointing to the buffer to plug with the
        // existing memref infrastructure. This may use more registers than
        // otherwise necessary given that memref sizes are fixed, but we can try
        // and canonicalize that away later.
        Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
        auto type = cast<MemRefType>(attribution.getType());
        Value descr = MemRefDescriptor::fromStaticShape(
            rewriter, loc, *getTypeConverter(), type, memory);
        signatureConversion.remapInput(numProperArguments + idx, descr);
      }
    }

    // Rewrite private memory attributions to alloca'ed buffers.
    unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
    auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
    for (const auto [idx, attribution] :
         llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
      auto type = cast<MemRefType>(attribution.getType());
      assert(type.hasStaticShape() && "unexpected type in attribution");

      // Explicitly drop memory space when lowering private memory
      // attributions since NVVM models it as `alloca`s in the default
      // memory space and does not support `alloca`s with addrspace(5).
      Type elementType = typeConverter->convertType(type.getElementType());
      // Do not build an alloca with a null element type.
      if (!elementType)
        return rewriter.notifyMatchFailure(
            gpuFuncOp, "failed to convert private attribution element type");
      auto ptrType =
          LLVM::LLVMPointerType::get(rewriter.getContext(), allocaAddrSpace);
      Value numElements = LLVM::ConstantOp::create(
          rewriter, gpuFuncOp.getLoc(), int64Ty, type.getNumElements());
      uint64_t alignment = 0;
      if (auto alignAttr =
              dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
                  idx, LLVM::LLVMDialect::getAlignAttrName())))
        alignment = alignAttr.getInt();
      Value allocated =
          LLVM::AllocaOp::create(rewriter, gpuFuncOp.getLoc(), ptrType,
                                 elementType, numElements, alignment);
      Value descr = MemRefDescriptor::fromStaticShape(
          rewriter, loc, *getTypeConverter(), type, allocated);
      signatureConversion.remapInput(
          numProperArguments + numWorkgroupAttributions + idx, descr);
    }
  }

  // Move the region to the new function, update the entry block signature.
  rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
                              llvmFuncOp.end());
  if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
                                         &signatureConversion)))
    return failure();

  // Get memref type from function arguments and set the noalias to
  // pointer arguments.
  for (const auto [idx, argTy] :
       llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
    auto remapping = signatureConversion.getInputMapping(idx);
    NamedAttrList argAttr =
        argAttrs ? cast<DictionaryAttr>(argAttrs[idx]) : NamedAttrList();
    // Nothing to copy if there are no attributes. Also skip when the argument
    // has no input mapping, so `remapping` is never dereferenced while empty.
    if (!remapping || argAttr.empty())
      continue;

    // Copies `attrName`, if present, onto every LLVM argument this input was
    // expanded into.
    auto copyAttribute = [&](StringRef attrName) {
      Attribute attr = argAttr.erase(attrName);
      if (!attr)
        return;
      for (size_t i = 0, e = remapping->size; i < e; ++i)
        llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
    };
    // Copies `attrName`, if present, onto only the LLVM arguments of pointer
    // type that this input was expanded into.
    auto copyPointerAttribute = [&](StringRef attrName) {
      Attribute attr = argAttr.erase(attrName);

      if (!attr)
        return;
      if (remapping->size > 1 &&
          attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
        emitWarning(llvmFuncOp.getLoc(),
                    "Cannot copy noalias with non-bare pointers.\n");
        return;
      }
      for (size_t i = 0, e = remapping->size; i < e; ++i) {
        if (isa<LLVM::LLVMPointerType>(
                llvmFuncOp.getArgument(remapping->inputNo + i).getType())) {
          llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
        }
      }
    };

    copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
    copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
    copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
    bool lowersToPointer = false;
    for (size_t i = 0, e = remapping->size; i < e; ++i) {
      lowersToPointer |= isa<LLVM::LLVMPointerType>(
          llvmFuncOp.getArgument(remapping->inputNo + i).getType());
    }

    if (lowersToPointer) {
      copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
      copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
      copyPointerAttribute(
          LLVM::LLVMDialect::getDereferenceableOrNullAttrName());
      copyPointerAttribute(
          LLVM::LLVMDialect::WorkgroupAttributionAttrHelper::getNameStr());
    }
  }
  rewriter.eraseOp(gpuFuncOp);
  return success();
}

/// Lowers `gpu.printf` to the `__ockl_printf_*` hostcall sequence.
///
/// The sequence is:
/// 1. `__ockl_printf_begin(0)`.
/// 2. `__ockl_printf_append_string_n` for the NUL-terminated format string.
/// 3. One `__ockl_printf_append_args` call per group of up to seven arguments.
///
/// Each call consumes and returns the running i64 descriptor. The final call
/// receives `isLast = 1` and all earlier calls receive `isLast = 0`. Arguments
/// are passed as i64 slots: floats are extended to f64 and bitcast, and
/// narrower integers are zero-extended.
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();
  MLIRContext *ctx = rewriter.getContext();

  Type i8Ty = typeConverter->convertType(rewriter.getI8Type());
  auto genericPtrTy = LLVM::LLVMPointerType::get(ctx);
  Type i32Ty = typeConverter->convertType(rewriter.getI32Type());
  Type i64Ty = typeConverter->convertType(rewriter.getI64Type());

  Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!moduleOp)
    return rewriter.notifyMatchFailure(gpuPrintfOp,
                                       "Couldn't find a parent module");

  auto printfArgs = adaptor.getArgs();
  const bool hasArgs = !printfArgs.empty();

  // Look up or declare the hostcall entry points. `__ockl_printf_append_args`
  // is only needed when there is at least one argument.
  LLVM::LLVMFuncOp beginFn =
      getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
                          LLVM::LLVMFunctionType::get(i64Ty, {i64Ty}));
  LLVM::LLVMFuncOp appendArgsFn;
  if (hasArgs) {
    appendArgsFn = getOrDefineFunction(
        moduleOp, loc, rewriter, "__ockl_printf_append_args",
        LLVM::LLVMFunctionType::get(
            i64Ty, {i64Ty, /*numArgs*/ i32Ty, i64Ty, i64Ty, i64Ty, i64Ty,
                    i64Ty, i64Ty, i64Ty, /*isLast*/ i32Ty}));
  }
  LLVM::LLVMFuncOp appendStringFn = getOrDefineFunction(
      moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
      LLVM::LLVMFunctionType::get(
          i64Ty,
          {i64Ty, genericPtrTy, /*length (bytes)*/ i64Ty, /*isLast*/ i32Ty}));

  // Open the printf stream.
  Value zeroI64 = LLVM::ConstantOp::create(rewriter, loc, i64Ty, 0);
  Value descriptor =
      LLVM::CallOp::create(rewriter, loc, beginFn, zeroI64).getResult();

  // Locate (or emit) the format-string global and address its first byte.
  LLVM::GlobalOp formatGlobal = getOrCreateStringConstant(
      rewriter, loc, moduleOp, i8Ty, "printfFormat_", adaptor.getFormat());
  Value formatAddr = LLVM::AddressOfOp::create(
      rewriter, loc,
      LLVM::LLVMPointerType::get(ctx, formatGlobal.getAddrSpace()),
      formatGlobal.getSymNameAttr());
  Value formatStart = LLVM::GEPOp::create(
      rewriter, loc, genericPtrTy, formatGlobal.getGlobalType(), formatAddr,
      ArrayRef<LLVM::GEPArg>{0, 0});
  // The stored value includes the NUL terminator, so this length does too.
  Value formatLen = LLVM::ConstantOp::create(
      rewriter, loc, i64Ty,
      cast<StringAttr>(formatGlobal.getValueAttr()).size());

  Value oneI32 = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 1);
  Value zeroI32 = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);

  // The format string terminates the stream only when no arguments follow.
  Value formatIsLast = hasArgs ? zeroI32 : oneI32;
  descriptor =
      LLVM::CallOp::create(
          rewriter, loc, appendStringFn,
          ValueRange{descriptor, formatStart, formatLen, formatIsLast})
          .getResult();

  // Converts one printf operand into an i64 slot value.
  auto toI64Slot = [&](Value value) -> Value {
    if (auto fpTy = dyn_cast<FloatType>(value.getType())) {
      if (!fpTy.isF64())
        value = LLVM::FPExtOp::create(
            rewriter, loc, typeConverter->convertType(rewriter.getF64Type()),
            value);
      value = LLVM::BitcastOp::create(rewriter, loc, i64Ty, value);
    }
    if (value.getType().getIntOrFloatBitWidth() != 64)
      value = LLVM::ZExtOp::create(rewriter, loc, i64Ty, value);
    return value;
  };

  // Each append_args call carries exactly seven i64 slots.
  constexpr size_t kSlotsPerCall = 7;
  const size_t totalArgs = printfArgs.size();
  for (size_t first = 0; first < totalArgs; first += kSlotsPerCall) {
    const size_t last = std::min(first + kSlotsPerCall, totalArgs);
    const size_t used = last - first;

    SmallVector<mlir::Value, 2 + kSlotsPerCall + 1> callOperands;
    callOperands.push_back(descriptor);
    callOperands.push_back(
        LLVM::ConstantOp::create(rewriter, loc, i32Ty, used));
    for (size_t i = first; i < last; ++i)
      callOperands.push_back(toI64Slot(printfArgs[i]));
    // Unused slots are zero-filled.
    callOperands.append(kSlotsPerCall - used, zeroI64);
    callOperands.push_back(last == totalArgs ? oneI32 : zeroI32);

    descriptor =
        LLVM::CallOp::create(rewriter, loc, appendArgsFn, callOperands)
            .getResult();
  }

  rewriter.eraseOp(gpuPrintfOp);
  return success();
}

/// Lowers `gpu.printf` to a direct call to a variadic `printf` declaration.
///
/// The declaration has type `i32 (ptr in addressSpace, ...)`. The first operand
/// points at the NUL-terminated format-string global, which lives in
/// `addressSpace`. It is followed by the converted printf operands, unchanged.
LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();
  MLIRContext *ctx = rewriter.getContext();

  Type byteTy = typeConverter->convertType(rewriter.getIntegerType(8));
  Type formatPtrTy = LLVM::LLVMPointerType::get(ctx, addressSpace);

  Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!moduleOp)
    return rewriter.notifyMatchFailure(gpuPrintfOp,
                                       "Couldn't find a parent module");

  // int printf(const char *format, ...)
  LLVM::LLVMFuncOp printfFn = getOrDefineFunction(
      moduleOp, loc, rewriter, "printf",
      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {formatPtrTy},
                                  /*isVarArg=*/true));

  // Reuse or emit the format-string global in the requested address space.
  LLVM::GlobalOp formatGlobal = getOrCreateStringConstant(
      rewriter, loc, moduleOp, byteTy, "printfFormat_", adaptor.getFormat(),
      /*alignment=*/0, addressSpace);

  // Address of the first character of the format string.
  Value formatAddr = LLVM::AddressOfOp::create(
      rewriter, loc,
      LLVM::LLVMPointerType::get(ctx, formatGlobal.getAddrSpace()),
      formatGlobal.getSymNameAttr());
  Value formatStart = LLVM::GEPOp::create(
      rewriter, loc, formatPtrTy, formatGlobal.getGlobalType(), formatAddr,
      ArrayRef<LLVM::GEPArg>{0, 0});

  // printf(format, args...)
  SmallVector<Value, 4> callOperands{formatStart};
  llvm::append_range(callOperands, adaptor.getArgs());
  LLVM::CallOp::create(rewriter, loc, printfFn, callOperands);

  rewriter.eraseOp(gpuPrintfOp);
  return success();
}

/// Lowers `gpu.printf` to a call to `i32 vprintf(ptr format, ptr args)`.
///
/// The format string is stored as a constant global. Each operand is promoted
/// (floating-point values to f64) and stored into a stack-allocated literal
/// struct. The struct's address is passed as the second `vprintf` operand.
///
/// Values that are already f64 are stored as-is. Extending f64 to f64 with
/// `fpext` would be invalid LLVM IR because `fpext` requires a strictly
/// narrower source type. This matches the `!isF64()` guard used by the HIP
/// printf lowering.
LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
  mlir::Type ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());

  Operation *moduleOp = gpuPrintfOp->getParentWithTrait<OpTrait::SymbolTable>();
  if (!moduleOp)
    return rewriter.notifyMatchFailure(gpuPrintfOp,
                                       "Couldn't find a parent module");

  // Create a valid global location removing any metadata attached to the
  // location as debug info metadata inside of a function cannot be used outside
  // of that function.
  Location globalLoc = loc->findInstanceOfOrUnknown<FileLineColLoc>();

  auto vprintfType =
      LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType, ptrType});
  LLVM::LLVMFuncOp vprintfDecl = getOrDefineFunction(
      moduleOp, globalLoc, rewriter, "vprintf", vprintfType);

  // Create the global op or find an existing one.
  LLVM::GlobalOp global =
      getOrCreateStringConstant(rewriter, globalLoc, moduleOp, llvmI8,
                                "printfFormat_", adaptor.getFormat());

  // Get a pointer to the format string's first element
  Value globalPtr = LLVM::AddressOfOp::create(rewriter, loc, global);
  Value stringStart =
      LLVM::GEPOp::create(rewriter, loc, ptrType, global.getGlobalType(),
                          globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
  SmallVector<Type> types;
  SmallVector<Value> args;
  // Promote and pack the arguments into a stack allocation.
  for (Value arg : adaptor.getArgs()) {
    Type type = arg.getType();
    Value promotedArg = arg;
    assert(type.isIntOrFloat());
    if (auto floatType = dyn_cast<FloatType>(type)) {
      // Floating-point values are passed as double. Only extend when the value
      // is not already f64, since `fpext f64 -> f64` is invalid.
      type = rewriter.getF64Type();
      if (!floatType.isF64())
        promotedArg = LLVM::FPExtOp::create(rewriter, loc, type, arg);
    }
    types.push_back(type);
    args.push_back(promotedArg);
  }
  Type structType =
      LLVM::LLVMStructType::getLiteral(gpuPrintfOp.getContext(), types);
  Value one = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                       rewriter.getIndexAttr(1));
  Value tempAlloc =
      LLVM::AllocaOp::create(rewriter, loc, ptrType, structType, one,
                             /*alignment=*/0);
  // Store each promoted argument into its own struct field.
  for (auto [index, arg] : llvm::enumerate(args)) {
    Value ptr = LLVM::GEPOp::create(
        rewriter, loc, ptrType, structType, tempAlloc,
        ArrayRef<LLVM::GEPArg>{0, static_cast<int32_t>(index)});
    LLVM::StoreOp::create(rewriter, loc, arg, ptr);
  }
  std::array<Value, 2> printfArgs = {stringStart, tempAlloc};

  LLVM::CallOp::create(rewriter, loc, vprintfDecl, printfArgs);
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}

/// Helper for impl::scalarizeVectorOp. Scalarizes vectors to elements.
/// Used either directly (for ops on 1D vectors) or as the callback passed to
/// detail::handleMultidimensionalVectors (for ops on higher-rank vectors).
///
/// The result vector starts as `llvm.mlir.poison`. For each lane `i` of
/// `llvm1DVectorTy`:
/// - vector operands are replaced by their `i`-th element via
///   `llvm.extractelement`, and non-vector operands are passed through;
/// - a single-result scalar op is built with `op`'s name and attributes,
///   yielding the vector's element type;
/// - that result is inserted into lane `i` with `llvm.insertelement`.
static Value scalarizeVectorOpHelper(Operation *op, ValueRange operands,
                                     Type llvm1DVectorTy,
                                     ConversionPatternRewriter &rewriter,
                                     const LLVMTypeConverter &converter) {
  VectorType vectorType = cast<VectorType>(llvm1DVectorTy);
  Location loc = op->getLoc();
  Value result = LLVM::PoisonOp::create(rewriter, loc, vectorType);
  Type indexType = converter.convertType(rewriter.getIndexType());
  StringAttr name = op->getName().getIdentifier();
  Type elementType = vectorType.getElementType();
  // Loop-invariant lane count.
  const int64_t numElements = vectorType.getNumElements();

  for (int64_t i = 0; i < numElements; ++i) {
    Value index = LLVM::ConstantOp::create(rewriter, loc, indexType, i);
    auto extractElement = [&](Value operand) -> Value {
      if (!isa<VectorType>(operand.getType()))
        return operand;
      return LLVM::ExtractElementOp::create(rewriter, loc, operand, index);
    };
    auto scalarOperands = llvm::map_to_vector(operands, extractElement);
    Operation *scalarOp =
        rewriter.create(loc, name, scalarOperands, elementType, op->getAttrs());
    result = LLVM::InsertElementOp::create(rewriter, loc, result,
                                           scalarOp->getResult(0), index);
  }
  return result;
}

/// Unrolls op to array/vector elements.
///
/// If any operand is a vector, the op is replaced by a lane-by-lane scalar
/// expansion whose vector type is the converted type of its first result. If
/// there are no vector operands but some operand is an `llvm.array`, the
/// expansion is delegated to `handleMultidimensionalVectors`, which scalarizes
/// each innermost 1-D vector. Otherwise the pattern reports a match failure.
LogicalResult impl::scalarizeVectorOp(Operation *op, ValueRange operands,
                                      ConversionPatternRewriter &rewriter,
                                      const LLVMTypeConverter &converter) {
  TypeRange types(operands);

  if (!llvm::any_of(types, llvm::IsaPred<VectorType>)) {
    if (!llvm::any_of(types, llvm::IsaPred<LLVM::LLVMArrayType>))
      return rewriter.notifyMatchFailure(op,
                                         "no llvm.array or vector to unroll");

    // Array-of-vector operands: scalarize each innermost 1-D slice.
    auto scalarizeSlice = [&](Type sliceVectorTy,
                              ValueRange sliceOperands) -> Value {
      return scalarizeVectorOpHelper(op, sliceOperands, sliceVectorTy, rewriter,
                                     converter);
    };
    return LLVM::detail::handleMultidimensionalVectors(
        op, operands, converter, scalarizeSlice, rewriter);
  }

  // Plain 1-D vector operands.
  auto resultVecTy =
      cast<VectorType>(converter.convertType(op->getResultTypes()[0]));
  Value unrolled =
      scalarizeVectorOpHelper(op, operands, resultVecTy, rewriter, converter);
  rewriter.replaceOp(op, unrolled);
  return success();
}

/// Wraps the numeric memory space `space` as a 64-bit integer attribute.
static IntegerAttr wrapNumericMemorySpace(MLIRContext *ctx, unsigned space) {
  auto i64Ty = IntegerType::get(ctx, 64);
  return IntegerAttr::get(i64Ty, space);
}

/// Generates a symbol with 0-sized array type for dynamic shared memory usage,
/// or uses existing symbol.
///
/// An existing `llvm.mlir.global` in `moduleOp` is reused when it is a
/// zero-element array in the same address space with the same alignment.
/// Otherwise a new internal global named `__dynamic_shmem_<N>` is inserted at
/// the start of the module body.
///
/// \returns the global, or a null op (after emitting an error on `op`) when
///          the memref's memory space cannot be mapped to an integer address
///          space.
static LLVM::GlobalOp getDynamicSharedMemorySymbol(
    ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp,
    gpu::DynamicSharedMemoryOp op, const LLVMTypeConverter *typeConverter,
    MemRefType memrefType, unsigned alignmentBit) {
  // NOTE(review): dividing by the element bit width gives a byte count only
  // for 8-bit elements; the caller passes a memref<0xi8>-style type.
  uint64_t alignmentByte = alignmentBit / memrefType.getElementTypeBitWidth();

  FailureOr<unsigned> addressSpace =
      typeConverter->getMemRefAddressSpace(memrefType);
  if (failed(addressSpace)) {
    op->emitError() << "conversion of memref memory space "
                    << memrefType.getMemorySpace()
                    << " to integer address space "
                       "failed. Consider adding memory space conversions.";
    // There is no numeric address space to match against or to create the
    // global in; dereferencing `addressSpace` below would be invalid.
    return {};
  }

  // Step 1. Collect symbol names of LLVM::GlobalOp Ops. Also if any of
  // LLVM::GlobalOp is suitable for shared memory, return it.
  llvm::StringSet<> existingGlobalNames;
  for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) {
    existingGlobalNames.insert(globalOp.getSymName());
    if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
      if (globalOp.getAddrSpace() == *addressSpace &&
          arrayType.getNumElements() == 0 &&
          globalOp.getAlignment().value_or(0) == alignmentByte) {
        return globalOp;
      }
    }
  }

  // Step 2. Find a unique symbol name. A candidate is rejected if it names any
  // symbol in the module (not only an llvm.mlir.global), so the new global
  // cannot collide with e.g. a function of the same name.
  unsigned uniquingCounter = 0;
  SmallString<128> symName = SymbolTable::generateSymbolName<128>(
      "__dynamic_shmem_",
      [&](StringRef candidate) {
        return existingGlobalNames.contains(candidate) ||
               SymbolTable::lookupSymbolIn(moduleOp.getOperation(),
                                           candidate) != nullptr;
      },
      uniquingCounter);

  // Step 3. Generate a global op
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointToStart(moduleOp.getBody());

  auto zeroSizedArrayType = LLVM::LLVMArrayType::get(
      typeConverter->convertType(memrefType.getElementType()), 0);

  return LLVM::GlobalOp::create(rewriter, op->getLoc(), zeroSizedArrayType,
                                /*isConstant=*/false, LLVM::Linkage::Internal,
                                symName, /*value=*/Attribute(), alignmentByte,
                                *addressSpace);
}

/// Lowers gpu.dynamic_shared_memory to the address of a zero-sized global
/// symbol in the enclosing gpu.module, wrapped in a memref descriptor.
LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
    gpu::DynamicSharedMemoryOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  MemRefType memrefType = op.getResultMemref().getType();
  Type elementType = typeConverter->convertType(memrefType.getElementType());

  // Step 1: Generate a memref<0xi8> type
  MemRefLayoutAttrInterface layout = {};
  auto memrefType0sz =
      MemRefType::get({0}, elementType, layout, memrefType.getMemorySpace());

  // Step 2: Generate a global symbol or existing for the dynamic shared
  // memory with memref<0xi8> type
  auto moduleOp = op->getParentOfType<gpu::GPUModuleOp>();
  if (!moduleOp)
    return rewriter.notifyMatchFailure(op, "expected to be inside gpu.module");
  LLVM::GlobalOp shmemOp = getDynamicSharedMemorySymbol(
      rewriter, moduleOp, op, getTypeConverter(), memrefType0sz, alignmentBit);
  if (!shmemOp) {
    // A diagnostic has already been emitted on `op`.
    return rewriter.notifyMatchFailure(
        op, "failed to get or create the dynamic shared memory global");
  }

  // Step 3. Get address of the global symbol
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(op);
  auto basePtr = LLVM::AddressOfOp::create(rewriter, loc, shmemOp);
  Type baseType = basePtr->getResultTypes().front();

  // Step 4. Generate GEP using offsets
  SmallVector<LLVM::GEPArg> gepArgs = {0};
  Value shmemPtr = LLVM::GEPOp::create(rewriter, loc, baseType, elementType,
                                       basePtr, gepArgs);

  // Step 5. Create a memref descriptor. `sizeBytes` is a required out-param of
  // getMemRefDescriptorSizes but is not needed here.
  SmallVector<Value> shape, strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memrefType0sz, {}, rewriter, shape, strides,
                           sizeBytes);
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memrefType0sz, shmemPtr, shmemPtr, shape, strides, rewriter);

  // Step 6. Replace the op with memref descriptor
  rewriter.replaceOp(op, {memRefDescriptor});
  return success();
}

/// Lowers gpu.return to llvm.return.
///
/// Under the bare-pointer calling convention, each returned memref that can be
/// converted to a bare pointer is replaced by its descriptor's allocated
/// pointer; unranked memrefs are rejected. Otherwise the converted operands
/// are used, with unranked memref descriptors first copied via
/// copyUnrankedDescriptors (toDynamic=true). Zero or one result is returned
/// directly; multiple results are packed into the struct type produced by
/// packFunctionResults.
LogicalResult GPUReturnOpLowering::matchAndRewrite(
    gpu::ReturnOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  unsigned numArguments = op.getNumOperands();
  SmallVector<Value, 4> updatedOperands;

  bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
  if (useBarePtrCallConv) {
    // For the bare-ptr calling convention, return the allocated pointer
    // extracted from the memref descriptor instead of the whole descriptor.
    for (auto [oldOperand, adaptedOperand] :
         llvm::zip(op->getOperands(), adaptor.getOperands())) {
      Type oldTy = oldOperand.getType();
      Value newOperand = adaptedOperand;
      if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
                                        cast<BaseMemRefType>(oldTy))) {
        MemRefDescriptor memrefDesc(newOperand);
        newOperand = memrefDesc.allocatedPtr(rewriter, loc);
      } else if (isa<UnrankedMemRefType>(oldTy)) {
        return rewriter.notifyMatchFailure(
            op, "unranked memref is not supported in the bare pointer calling "
                "convention");
      }
      updatedOperands.push_back(newOperand);
    }
  } else {
    updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
    // If this fails, unranked descriptors would be returned without being
    // copied, so report it rather than silently continuing.
    if (failed(copyUnrankedDescriptors(rewriter, loc,
                                       op.getOperands().getTypes(),
                                       updatedOperands,
                                       /*toDynamic=*/true)))
      return rewriter.notifyMatchFailure(
          op, "failed to copy unranked memref descriptors for return");
  }

  // If ReturnOp has 0 or 1 operand, create it and return immediately.
  if (numArguments <= 1) {
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
        op, TypeRange(), updatedOperands, op->getAttrs());
    return success();
  }

  // Otherwise, we need to pack the arguments into an LLVM struct type before
  // returning.
  auto packedType = getTypeConverter()->packFunctionResults(
      op.getOperandTypes(), useBarePtrCallConv);
  if (!packedType) {
    return rewriter.notifyMatchFailure(op, "could not convert result types");
  }

  Value packed = LLVM::PoisonOp::create(rewriter, loc, packedType);
  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
    packed = LLVM::InsertValueOp::create(rewriter, loc, packed, operand, idx);
  }
  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
                                              op->getAttrs());
  return success();
}

/// Registers a type-attribute conversion on `typeConverter` that rewrites a
/// memref's `gpu::AddressSpaceAttr` memory space into the numeric address
/// space returned by `mapping`, wrapped as an integer attribute.
void mlir::populateGpuMemorySpaceAttributeConversions(
    TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
  typeConverter.addTypeAttributeConversion(
      [mapping](BaseMemRefType type, gpu::AddressSpaceAttr spaceAttr) {
        // `mapping` is captured by value so the callback stays valid after
        // this function returns.
        MLIRContext *context = spaceAttr.getContext();
        unsigned numericSpace = mapping(spaceAttr.getValue());
        return wrapNumericMemorySpace(context, numericSpace);
      });
}
